####################################################
#Author: Kelli Marquardt
#Purpose: Produce various tables/figures for main paper (using patient data)

# Inputs:
#-  data/est_dat_fake.csv 

# Outputs:
#-  output/figures/fig_3.png, fig_c2.png
#-  output/tables/tab_2.txt, tab_3.txt, tab_4.txt
#-  output/tables/tab_a1.txt, tab_a8.txt
#-  output/tables/tab_c2.txt, tab_c3.txt

####################################################


############################
#0 load required packages
############################
rm(list = ls(all.names = TRUE))

#load packages
library(dplyr) 
library(fastDummies)
library(ggplot2)
library(sandwich)
library(lmtest)
library(tidyr) #for pivot wider
library(scales) #for 'nice' graph limits



############################
#0 Define functions used in code 
############################

# Write a LaTeX table fragment to ../output/tables/<file_name>.txt.
#   header, footer, body: character strings; they are concatenated with no
#                         separator in the order header -> body -> footer
#   file_name:            name of the output file without the ".txt" extension
save_tex_table <- function(header, footer, body, file_name) {
  out_path <- file.path("..", "output", "tables", paste0(file_name, ".txt"))
  tex_string <- paste0(header, body, footer)
  write(tex_string, file = out_path)
}

#################################################
#Step 1: read in data and separate into step1_dat and step2_dat
#################################################

# Patient-level estimation data; character columns are kept as strings
est_dat <- read.csv(
  file.path("..", "data", "est_dat_fake.csv"),
  stringsAsFactors = FALSE
)

# step1_dat: demographics, first PCP id, Qi, xi, plus every column whose
# name ends in "_uptoQ1"
step1_dat <- select(
  est_dat,
  pat_id, male, hispanic, white, Med, Com, birth_year,
  first_pcp_id, Qi, xi,
  ends_with("_uptoQ1")
)

# step2_dat: the columns used for figure 3 and table 4 below
step2_dat <- select(est_dat, pat_id, male, Qi, xi, Di, type2_max)

#################################################
#Step 2: Produce figure 3
#################################################

###
#make density plot of xi for male and female 

#first get min and max values 
# non-missing xi values drive both the axis breaks and the axis limits
x_vec <- step2_dat$xi[!is.na(step2_dat$xi)]
# "nice" evenly spaced breaks chosen by scales::extended_breaks
eq_breaks <- extended_breaks(n = 6)(x_vec)
# limits span both the data and the outermost breaks, so no break and no
# observation falls outside the axis
x_limits <- range(c(eq_breaks, x_vec))

###
#create the plot: kernel density of xi drawn as one line per sex
match_dist <- step2_dat %>%
  filter(!is.na(xi)) %>%
  mutate(sex = factor(ifelse(male == 1, "Male", "Female"),
                      levels = c("Male", "Female"))) %>%
  ggplot() +
  stat_density(aes(x = xi, color = sex, linetype = sex),
               geom = "line", position = "identity", linewidth = 1) +
  scale_linetype_manual(values = c("Male" = "dashed", "Female" = "solid")) +
  scale_color_manual(values = c("Male" = "cyan4", "Female" = "orange2")) +
  labs(x = "Observed ADHD Risk Signal", y = "") +
  theme_classic() +
  theme(
    legend.position = "bottom",
    legend.title = element_blank(),
    legend.text = element_text(size = 16),
    axis.text = element_text(size = 16),
    axis.title = element_text(size = 16),
    legend.key.width = unit(15, "mm"),
    legend.key.spacing = unit(5, "mm")
  ) +
  scale_x_continuous(limits = x_limits, breaks = eq_breaks)
#match_dist

# `filename` is spelled out in full (the previous `file =` only worked through
# partial argument matching); this also matches the fig_c2 ggsave() call
ggsave(filename = file.path("..", "output", "figures", "fig_3.png"),
       plot = match_dist,
       height = 5.8,
       width = 8)

#clean up
rm(match_dist, x_limits, eq_breaks, x_vec)


#################################################
#Step 3: Create MH statistics and comparison table (table 4)
#################################################

# Table 4: mean and sd of Di, Qi, Di restricted to Qi == 1, and xi, for the
# total sample and by sex, plus a pooled-variance t-test of the male - female
# difference in means.
mh_stats <- step2_dat %>%
  mutate(Di_condit = ifelse(Qi == 1, Di, NA)) %>%
  select(male, Di, Qi, Di_condit, xi)

# every column except the grouping variable, in column order
# (the row order of the printed table follows this order)
mh_vars <- setdiff(names(mh_stats), "male")
is_male <- mh_stats[["male"]] == 1
is_female <- mh_stats[["male"]] == 0

# one "<var>_mean" row and one "<var>_sd" row per variable; both rows carry the
# same difference / p-value, only the "_mean" row prints them
mh_rows <- lapply(mh_vars, function(v) {
  all_obs <- mh_stats[[v]]
  male_obs <- all_obs[is_male]
  female_obs <- all_obs[is_female]
  tt <- t.test(male_obs, female_obs, var.equal = TRUE)
  data.frame(
    stat = paste0(v, c("_mean", "_sd")),
    total = c(mean(all_obs, na.rm = TRUE), sd(all_obs, na.rm = TRUE)),
    male = c(mean(male_obs, na.rm = TRUE), sd(male_obs, na.rm = TRUE)),
    female = c(mean(female_obs, na.rm = TRUE), sd(female_obs, na.rm = TRUE)),
    difference = unname(tt$estimate[1] - tt$estimate[2]),
    pval_dif = tt$p.value,
    stringsAsFactors = FALSE
  )
})

# significance stars, printable difference, and LaTeX row labels.
# For the "_sd" rows the label is " & ": it supplies the empty first cell and
# the column separator that the sd row format below does not contain itself.
mh_stats_table <- bind_rows(mh_rows) %>%
  mutate(
    difference = round(difference, 4),
    stars = case_when(pval_dif < .01 ~ "***",
                      pval_dif < .05 ~ "**",
                      pval_dif < .1 ~ "*",
                      TRUE ~ ""),
    dif = paste0(difference, stars)
  ) %>%
  select(stat, total, male, female, dif) %>%
  mutate(stat_name = case_when(
    stat == "Di_mean" ~ "\\hspace{3mm}ADHD Dx. ",
    stat == "Qi_mean" ~ "\\hspace{3mm}Behav. Appt. $(Q_i)$ ",
    stat == "Di_condit_mean" ~ "\\hspace{3mm}ADHD Dx. ",
    stat == "xi_mean" ~ "\\hspace{3mm}ADHD Match Signal ($x_i$) ",
    TRUE ~ " & "
  ))

# sample sizes (total, male, female): full sample and Qi == 1 subsample
n_full <- c(nrow(step2_dat), sum(step2_dat$male))
n_full[3] <- n_full[1] - n_full[2]
n_behav <- c(sum(step2_dat$Qi), sum(step2_dat$Qi * step2_dat$male))
n_behav[3] <- n_behav[1] - n_behav[2]

# row formats (the whitespace is part of the written file; keep as is)
n_row_fmt <- " \\hspace{3mm} N   & %d & %d & %d & \\\\  \n"
sd_row_fmt <- " %s   (%.3f) & (%.3f) & (%.3f) &  \\\\ \n"

# Format table row i. The first row of the subsample panel (Di_condit_mean) is
# preceded by the full-sample N row, a rule and the panel title.
format_mh_row <- function(i) {
  cur <- mh_stats_table[i, ]
  if (cur$stat %in% c("Di_mean", "Qi_mean", "xi_mean")) {
    sprintf(" %s &  %.3f & %.3f & %.3f &  \\multirow{2}{*}{%s} \\\\ \n",
            cur$stat_name, cur$total, cur$male, cur$female, cur$dif)
  } else if (cur$stat == "Di_condit_mean") {
    paste0(
      sprintf(n_row_fmt, n_full[1], n_full[2], n_full[3]),
      "\\midrule  \n ",
      "\\multicolumn{5}{l}{\\textbf{Behavioral Assessment Subsample $(Q_i=1)$}}\\\\\n",
      sprintf(" %s  & %.3f & %.3f & %.3f &  \\multirow{2}{*}{%s} \\\\ \n",
              cur$stat_name, cur$total, cur$male, cur$female, cur$dif)
    )
  } else {
    sprintf(sd_row_fmt, cur$stat_name, cur$total, cur$male, cur$female)
  }
}

#produce the inside of the table
# (seq_len() instead of 1:nrow() so an empty table yields an empty body)
tab_hold <- paste0(
  vapply(seq_len(nrow(mh_stats_table)), format_mh_row, character(1)),
  collapse = ""
)

# add header and footer and save
# paste() with its default " " separator is kept on purpose so the written
# file is unchanged
header <- paste("\\begin{tabular}{lcccl}\n",
                "\\toprule \n",
                " & Total & Male & Female & Difference \\\\\n",
                "\\midrule \n",
                "\\textbf{Full Sample} & & & &\\\\\n")
footer <- paste(sprintf(n_row_fmt, n_behav[1], n_behav[2], n_behav[3]),
                "\\bottomrule\n",
                "\\end{tabular}\n")

#save the table
save_tex_table(header = header, footer = footer, body = tab_hold,
               file_name = "tab_4")


## clean up
rm(mh_stats, mh_stats_table, mh_rows, mh_vars, is_male, is_female)
rm(n_full, n_behav, n_row_fmt, sd_row_fmt, format_mh_row)
rm(tab_hold, footer, header)



#################################################
#Step 4: Reduced Form Regressions and table (table 3)
#################################################

#select variables needed from est_dat
rf_dat <- est_dat %>%
  select(pat_id, Di, male, starts_with("year"),
         age_mean, Med, Com, white, hispanic,
         behav, well, numdocs, numapt, psych_doc,
         starts_with("mh_dx_other"), birth_year) %>%
  select(-c(ends_with("uptoQ1"))) %>%
  #year should be indicators for year fe
  mutate(across(year_2014:year_2017, ~ as.integer(.x > 0))) %>%
  #define indicator for any other mh dx but not ADHD
  mutate(mh_dx_other = ifelse(
    (mh_dx_other_external == 1 | mh_dx_other_internal == 1) & Di == 0, 1, 0
  )) %>%
  #define age and age^2; birth-year fixed effects enter the models through
  #as.factor(birth_year), so no separate dummy columns are created
  rename(age = age_mean) %>%
  mutate(age2 = age^2)


#######
#model 1 (just male indicator)
rf1 <- lm(Di ~ male, data = rf_dat)

#model 2 (add in demographic vars: age, insurance, race/ethnicity, year of birth fe)
rf2 <- lm(Di ~ male +
            age + age2 + Med + Com + white + hispanic + as.factor(birth_year),
          data = rf_dat)

#model 3 (add in hc vars: numdocs, numapt, well, year fes)
rf3 <- lm(Di ~ male +
            age + age2 + Med + Com + white + hispanic + as.factor(birth_year) +
            year_2014 + year_2015 + year_2016 + year_2017 +
            numdocs + numapt + well,
          data = rf_dat)

#model 4 (add in mental health hc vars: psych_doc, behav, mh_dx_other)
rf4 <- lm(Di ~ male +
            age + age2 + Med + Com + white + hispanic + as.factor(birth_year) +
            year_2014 + year_2015 + year_2016 + year_2017 +
            numdocs + numapt + well +
            psych_doc + mh_dx_other + behav,
          data = rf_dat)

# For one fitted model: adjusted R-squared plus the coefficient on `male`
# with its HC1 heteroskedasticity-robust standard error and p-value.
male_robust_stats <- function(model) {
  robust <- coeftest(model, vcov = vcovHC(model, type = "HC1"))
  c(adj_r2 = summary(model)$adj.r.squared,
    est = unname(robust["male", "Estimate"]),
    se = unname(robust["male", "Std. Error"]),
    pval = unname(robust["male", "Pr(>|t|)"]))
}

rf_models <- list(rf1, rf2, rf3, rf4)
# one row per model, columns adj_r2 / est / se / pval
rf_stats <- do.call(rbind, lapply(rf_models, male_robust_stats))
# N is the number of observations each model was actually fit on: lm() drops
# rows with a missing value in any model variable, so nrow(rf_dat) can
# overstate the estimation sample of the richer specifications
n_obs <- vapply(rf_models, function(m) as.integer(nobs(m)), integer(1))

#######
#output table
#first need stars on coef
male_pval <- rf_stats[, "pval"]
stars <- case_when(male_pval < .01 ~ "***",
                   male_pval < .05 ~ "**",
                   male_pval < .1 ~ "*",
                   TRUE ~ "")
male_coef <- paste0(round(rf_stats[, "est"], 3), stars)

#get inside of table, and header/footer
body <- paste(
  sprintf(" \\textbf{Male}  & %s & %s & %s & %s \\\\ \n",
          male_coef[1], male_coef[2], male_coef[3], male_coef[4]),
  sprintf("  & (%.3f) & (%.3f) & (%.3f) & (%.3f) \\\\ \n",
          rf_stats[1, "se"], rf_stats[2, "se"],
          rf_stats[3, "se"], rf_stats[4, "se"]),
  "\\midrule \n",
  "\\multicolumn{3}{l}{\\textit{Added Patient Observables:}}\\\\ \n",
  "\\hspace{3mm} Demographics & N & Y & Y & Y \\\\ \n",
  "\\hspace{3mm} General Healthcare Utilization & N & N & Y & Y \\\\\n",
  "\\hspace{3mm} Mental Healthcare Utilization & N & N & N & Y \\\\\n",
  "\\midrule \n",
  sprintf(" Adj. R-squared  & %.4f & %.4f & %.4f & %.4f \\\\ \n",
          rf_stats[1, "adj_r2"], rf_stats[2, "adj_r2"],
          rf_stats[3, "adj_r2"], rf_stats[4, "adj_r2"]),
  sprintf("N & %d & %d & %d & %d \\\\\n",
          n_obs[1], n_obs[2], n_obs[3], n_obs[4])
)

header <- paste("\\begin{tabular}{lcccc}\n",
                "\\toprule \n",
                " & (1) & (2) & (3) & (4) \\\\\n",
                "\\midrule \n")
footer <- paste("\\bottomrule \n",
                "\\end{tabular}")

#save the table
save_tex_table(header = header, footer = footer, body = body,
               file_name = "tab_3")


#clean up
rm(rf_dat, rf_models, rf_stats, n_obs, male_robust_stats)
rm(male_pval, male_coef, stars)
rm(rf1, rf2, rf3, rf4)
rm(body, header, footer)



#################################################
#Step 5: Main Summary stat table (table 2)
#################################################

#vars for main table: Dx, male, age, white, hispanic, medicaid, #docs, #appt, #appt not pcp, #years in sample (N patients) 

#select variables needed from est_dat
sum_stat_dat <- est_dat %>%
  select(Di, male,
         age_mean, white, hispanic, Med,
         numdocs, numapt, numapt_notpcp_first,
         starts_with("year")) %>%
  select(-ends_with("uptoQ1")) %>%
  #number of years in sample, male dx, female dx
  mutate(across(year_2014:year_2017, ~ as.integer(.x > 0))) %>%
  mutate(numyears = year_2014 + year_2015 + year_2016 + year_2017,
         maledx = ifelse(male == 1, Di, NA),
         femaledx = ifelse(male == 0, Di, NA)) %>%
  #variable names must not contain "_": the summary columns are named
  #<variable>_<statistic> and are split on "_" when reshaping below
  rename(age = age_mean,
         numaptnotpcp = numapt_notpcp_first)

#get mean, sd, min, max; the select() order is the row order of the table
sum_stat_dat <- sum_stat_dat %>%
  select(Di, maledx, femaledx,
         male, age, white, hispanic, Med,
         numdocs, numapt, numaptnotpcp, numyears) %>%
  summarise(across(everything(), list(
    mean = ~ mean(.x, na.rm = TRUE),
    sd = ~ sd(.x, na.rm = TRUE),
    min = ~ min(.x, na.rm = TRUE),
    max = ~ max(.x, na.rm = TRUE)
  ))) %>%
  #reformat stats to columns and vars to rows
  pivot_longer(cols = everything(),
               names_to = c("variable", "statistic"),
               names_sep = "_") %>%
  pivot_wider(names_from = "statistic",
              values_from = "value")

### Format the table
#printed (LaTeX) label for each variable
var_labels <- c(
  Di = "ADHD Dx.",
  maledx = "\\hspace{3mm} Male Dx.",
  femaledx = "\\hspace{3mm} Female Dx.",
  male = "Male",
  age = "Age",
  white = "White",
  hispanic = "Hispanic",
  Med = "Medicaid",
  numdocs = "\\# of Physicians",
  numapt = "\\# of Appt.",
  numaptnotpcp = "\\# of Appt. (not IPCP)",
  numyears = "\\# Yrs. in Sample"
)
sum_stat_dat <- sum_stat_dat %>%
  mutate(var_name = unname(var_labels[variable]))

#produce the inside of the table
# one formatted line per variable, built in a single vectorised sprintf()
# call rather than a 1:nrow() loop (which misbehaves on zero rows)
row_txt <- sprintf(" %s & %.3f & %.3f & %.0f & %.0f  \\\\\n ",
                   sum_stat_dat$var_name, sum_stat_dat$mean, sum_stat_dat$sd,
                   sum_stat_dat$min, sum_stat_dat$max)
# extra vertical space before the "male" and "numdocs" rows
spacer <- ifelse(sum_stat_dat$variable %in% c("male", "numdocs"),
                 "\\addlinespace \n ", "")
tab_hold <- paste0(spacer, row_txt, collapse = "")


footer <- paste("\\midrule \n ",
                sprintf("N Patients  &  %d &  &   & \\\\\n", nrow(est_dat)),
                "\\bottomrule \n",
                "\\end{tabular}")


#add header and save table
header <- paste("\\begin{tabular}{lcccc}\n",
                "\\toprule \n",
                " & Mean & Std. Dev. & Min & Max \\\\\n",
                "\\midrule \n")

save_tex_table(header = header, footer = footer, body = tab_hold,
               file_name = "tab_2")

#clean up
rm(tab_hold, row_txt, spacer, var_labels, sum_stat_dat)
rm(header, footer)



#################################################
#Step 6: Appendix Summary stat table - male.female comparisons (table A1)
#################################################
#full sample and behavioral sample 
  #age, non-hispanic white, non-white hispanic, medicaid, private ins, N 

#define and select variables from est_dat
compare_dat <- est_dat %>%
  mutate(white_nh = as.integer(white == 1 & hispanic == 0),
         nw_hisp = as.integer(white == 0 & hispanic == 1)) %>%
  select(Qi, male, age_mean, white_nh, nw_hisp, Med, Com)

# For every column of `dat` except male and Qi: male mean, female mean, their
# difference, and the p-value of a pooled-variance t-test of that difference.
# `suffix` is appended to the variable name to form the `stat` key.
compare_by_sex <- function(dat, suffix) {
  vars <- setdiff(names(dat), c("male", "Qi"))
  rows <- lapply(vars, function(v) {
    tt <- t.test(dat[[v]][dat[["male"]] == 1],
                 dat[[v]][dat[["male"]] == 0],
                 var.equal = TRUE)
    data.frame(stat = paste0(v, suffix),
               male = unname(tt$estimate[1]),
               female = unname(tt$estimate[2]),
               difference = unname(tt$estimate[1] - tt$estimate[2]),
               pval_dif = tt$p.value,
               stringsAsFactors = FALSE)
  })
  bind_rows(rows)
}

# Number of male and female patients in `dat` (rows with missing male are
# not counted in either group).
count_by_sex <- function(dat) {
  c(sum(dat$male == 1, na.rm = TRUE), sum(dat$male == 0, na.rm = TRUE))
}

#full sample, then the same comparison for those with Qi=1
compare_sub <- compare_dat %>% filter(Qi == 1)
compare_tab <- bind_rows(compare_by_sex(compare_dat, "_full"),
                         compare_by_sex(compare_sub, "_sub"))
N_full <- count_by_sex(compare_dat)
N_sub <- count_by_sex(compare_sub)

#add dif_star variable, rename variables and keep only needed columns
compare_tab <- compare_tab %>%
  mutate(stars = case_when(pval_dif < .01 ~ "***",
                           pval_dif < .05 ~ "**",
                           pval_dif < .1 ~ "*",
                           TRUE ~ ""),
         dif_star = paste0(round(difference, 3), stars),
         stat_name = sub("(_full|_sub)", "", stat)) %>%
  mutate(stat_name = case_when(stat_name == "age_mean" ~ "Age",
                               stat_name == "white_nh" ~ "Non-Hispanic White",
                               stat_name == "nw_hisp" ~ "Non-White Hispanic",
                               stat_name == "Med" ~ "Medicaid",
                               stat_name == "Com" ~ "Private Ins.")) %>%
  select(stat, male, female, dif_star, stat_name)


##create table
#produce the inside of the table
n_row_fmt <- "N & %d & %d & \\\\\n"
row_txt <- sprintf(" %s  & %.3f & %.3f & %s \\\\ \n",
                   compare_tab$stat_name, compare_tab$male,
                   compare_tab$female, compare_tab$dif_star)
# the first subsample row is preceded by the full-sample N row and the title
# of the second panel
panel_break <- paste0(
  sprintf(n_row_fmt, N_full[1], N_full[2]), " ",
  "\\multicolumn{3}{l}{\\textbf{Behavioral Assessment Sample}}\\\\\n", " "
)
tab_hold <- paste0(
  ifelse(compare_tab$stat == "age_mean_sub", panel_break, ""),
  row_txt,
  collapse = ""
)

#add header and footer
header <- paste("\\begin{tabular}{lccc}\n",
                "\\toprule \n",
                "  & Male & Female & Difference \\\\\n",
                "\\midrule \n",
                "\\textbf{Full Sample} & & &\\\\\n")

# a single sprintf(): the previous nested call re-parsed the already formatted
# text as a format string
footer <- paste(sprintf(n_row_fmt, N_sub[1], N_sub[2]),
                "\\bottomrule\n",
                "\\end{tabular}\n")

#save table
save_tex_table(header = header, footer = footer, body = tab_hold,
               file_name = "tab_a1")

## clean up
rm(compare_dat, compare_sub, compare_tab)
rm(compare_by_sex, count_by_sex)
rm(N_full, N_sub, tab_hold, row_txt, panel_break, n_row_fmt)
rm(header, footer)



#################################################
#Step 7: Appendix table Other Mental Health Dx (table A8)
#################################################
#full sample and behavioral sample 
#ADHD, other external, other internal 

#define and select variables from est_dat
other_dx_dat <- est_dat %>%
  select(Qi, male, Di, mh_dx_other_external, mh_dx_other_internal)

# Mean of every column of `dat` except male and Qi, for all rows and
# separately for male == 1 and male == 0. `suffix` is appended to the variable
# name to form the `stat` key. Missing values are dropped (na.rm = TRUE), as
# in the other summary tables, so a single NA does not blank out a cell.
dx_means_by_sex <- function(dat, suffix) {
  dx_vars <- setdiff(names(dat), c("male", "Qi"))
  is_male <- dat[["male"]] == 1
  is_female <- dat[["male"]] == 0
  col_mean <- function(keep) {
    vapply(dx_vars, function(v) mean(dat[[v]][keep], na.rm = TRUE),
           numeric(1), USE.NAMES = FALSE)
  }
  data.frame(stat = paste0(dx_vars, suffix),
             total = col_mean(rep(TRUE, nrow(dat))),
             male = col_mean(is_male),
             female = col_mean(is_female),
             stringsAsFactors = FALSE)
}

#full sample rows first, then the same for those with Qi=1
other_dx_tab <- bind_rows(
  dx_means_by_sex(other_dx_dat, "_full"),
  dx_means_by_sex(other_dx_dat %>% filter(Qi == 1), "_sub")
)

#rename variables and keep only needed columns
other_dx_tab <- other_dx_tab %>%
  mutate(stat_name = sub("(_full|_sub)", "", stat)) %>%
  mutate(stat_name = case_when(
    stat_name == "Di" ~ "\\hspace{3mm} ADHD",
    stat_name == "mh_dx_other_external" ~ "\\hspace{3mm} Other External",
    stat_name == "mh_dx_other_internal" ~ "\\hspace{3mm} Other Internal"
  )) %>%
  select(stat, total, male, female, stat_name)


##create table
#produce the inside of the table
row_txt <- sprintf(" %s  & %.3f & %.3f & %.3f \\\\ \n",
                   other_dx_tab$stat_name, other_dx_tab$total,
                   other_dx_tab$male, other_dx_tab$female)
# the first subsample row is preceded by the title of the second panel
panel_title <- paste0(
  "\\multicolumn{4}{l}{\\textbf{Behavioral Assessment Subsample $(Q_i=1)$}}\\\\\n",
  " "
)
tab_hold <- paste0(
  ifelse(other_dx_tab$stat == "Di_sub", panel_title, ""),
  row_txt,
  collapse = ""
)

#add header and footer
header <- paste("\\begin{tabular}{lccc}\n",
                "\\toprule \n",
                "  & Total & Male & Female \\\\\n",
                "\\midrule \n",
                "\\textbf{Full Sample} & & &\\\\\n")


footer <- paste("\\bottomrule\n",
                "\\end{tabular}\n")

#save table
save_tex_table(header = header, footer = footer, body = tab_hold,
               file_name = "tab_a8")

## clean up
rm(other_dx_dat, other_dx_tab, dx_means_by_sex)
rm(row_txt, panel_title)
rm(header, footer, tab_hold)


#################################################
#Step 8: Note Length distribution figure (fig c2)
#################################################
# Qi == 1 patients only; sex and diagnosis recoded as labelled factors
# (the numeric `male` column is kept for the group means below)
nl_dat <- est_dat %>%
  filter(Qi == 1) %>%
  select(male, Di, note_length) %>%
  mutate(
    gender = factor(male, levels = c(0, 1), labels = c("Female", "Male")),
    Di = factor(Di, levels = c(0, 1), labels = c("No ADHD Dx", "ADHD Dx"))
  )

# mean note length within each diagnosis panel for one sex (male_flag 1 or 0)
dx_panels <- c("No ADHD Dx", "ADHD Dx")
mean_length_by_dx <- function(male_flag) {
  vapply(
    dx_panels,
    function(dx) {
      in_cell <- which(nl_dat$Di == dx & nl_dat$male == male_flag)
      mean(nl_dat$note_length[in_cell])
    },
    numeric(1),
    USE.NAMES = FALSE
  )
}
vline_data_m <- data.frame(Di = dx_panels, xint = mean_length_by_dx(1))
vline_data_f <- data.frame(Di = dx_panels, xint = mean_length_by_dx(0))

# Plot: overlaid histograms by sex (male layer first, female drawn on top),
# one facet per diagnosis status, dashed vertical lines at the group means
label_text <- element_text(size = 14)
note_length_dist <- ggplot(nl_dat, aes(x = note_length, fill = gender)) +
  geom_histogram(data = subset(nl_dat, gender == "Male"),
                 position = "identity", alpha = 0.7, bins = 30) +
  geom_histogram(data = subset(nl_dat, gender == "Female"),
                 position = "identity", alpha = 0.7, bins = 30) +
  geom_vline(data = vline_data_f, aes(xintercept = xint),
             linetype = "dashed", color = "orange2", linewidth = 1) +
  geom_vline(data = vline_data_m, aes(xintercept = xint),
             linetype = "dashed", color = "cyan4", linewidth = 1) +
  facet_wrap(~ Di) +
  labs(title = "", x = "Note Length", y = "Count", fill = "Gender") +
  scale_fill_manual(values = c("Female" = "orange2", "Male" = "cyan4")) +
  theme_classic() +
  theme(
    strip.text = label_text,      # facet titles
    legend.position = "bottom",   # legend below plot
    axis.title = label_text,      # x and y axis titles
    legend.title = label_text,
    legend.text = label_text
  )


#note_length_dist
ggsave(filename = file.path("..", "output", "figures", "fig_c2.png"),
       plot = note_length_dist,
       width = 8, height = 6, dpi = 300)

#clean up
rm(nl_dat, note_length_dist, vline_data_f, vline_data_m)
rm(dx_panels, mean_length_by_dx, label_text)


#################################################
#Step 9: XI correlation tables (C2 and C3)
#################################################
# Qi == 1 patients: all xi_* columns plus the mult_assessments flag
xi_cor_dat <- est_dat %>%
  filter(Qi == 1) %>%
  select(starts_with("xi_"), mult_assessments)


#start with c2: pairwise correlations of the four xi measures,
#printed as a lower-triangular table
xi_cor_tab1 <- round(
  cor(select(xi_cor_dat, xi_01, xi_30, xi_60, xi_all)),
  3
)

c2_lines <- c(
  "\\begin{tabular}{lcccc}",
  "\\toprule",
  "         & $<=$ Initial Dx  & $<=$ 30 days post & $<=$ 60 days post & All Visits  \\\\",
  "\\midrule",
  sprintf(" $<=$ Initial Dx  & %.3f &  & &  \\\\",
          xi_cor_tab1[1, 1]),
  sprintf(" $<=$ 30 days post    & %.3f & %.3f &  &  \\\\",
          xi_cor_tab1[1, 2], xi_cor_tab1[2, 2]),
  sprintf(" $<=$ 60 days post   & %.3f & %.3f & %.3f &  \\\\",
          xi_cor_tab1[1, 3], xi_cor_tab1[2, 3], xi_cor_tab1[3, 3]),
  sprintf("  All Visits  & %.3f & %.3f & %.3f & %.3f \\\\",
          xi_cor_tab1[1, 4], xi_cor_tab1[2, 4], xi_cor_tab1[3, 4],
          xi_cor_tab1[4, 4]),
  "\\bottomrule",
  "\\end{tabular}"
)
xi_corr_table <- paste(c2_lines, collapse = "\n")

#save
write(xi_corr_table, file = file.path("..", "output", "tables", "tab_c2.txt"))


#now do c3: same layout for the first three measures, restricted to rows
#with mult_assessments == 1
xi_cor_tab2 <- xi_cor_dat %>%
  filter(mult_assessments == 1) %>%
  select(xi_01, xi_30, xi_60) %>%
  cor() %>%
  round(3)

c3_lines <- c(
  "\\begin{tabular}{lccc}",
  "\\toprule",
  "         & $<=$ Initial Dx  & $<=$ 30 days post & $<=$ 60 days post   \\\\",
  "\\midrule",
  sprintf(" $<=$ Initial Dx  & %.3f &  &   \\\\",
          xi_cor_tab2[1, 1]),
  sprintf(" $<=$ 30 days post    & %.3f & %.3f &   \\\\",
          xi_cor_tab2[1, 2], xi_cor_tab2[2, 2]),
  sprintf(" $<=$ 60 days post   & %.3f & %.3f & %.3f   \\\\",
          xi_cor_tab2[1, 3], xi_cor_tab2[2, 3], xi_cor_tab2[3, 3]),
  "\\bottomrule",
  "\\end{tabular}"
)
xi_corr_table <- paste(c3_lines, collapse = "\n")

#save
write(xi_corr_table, file = file.path("..", "output", "tables", "tab_c3.txt"))

## clean up
rm(xi_cor_dat, xi_cor_tab1, xi_cor_tab2)
rm(c2_lines, c3_lines, xi_corr_table)


#END OF SCRIPT
